Conversation
aafa675 to
175c336
Compare
There was a problem hiding this comment.
Pull request overview
This pull request implements a tree reduction optimization for SDPA (Scaled Dot-Product Attention) decode operations, improving the reduction complexity from O(n-1) to O(log n) where n is the number of cores in a reduction group. The change replaces the flat worker-to-reducer pattern with a binary tree reduction where cores hierarchically combine their attention results.
Changes:
- Introduced tree reduction helper functions (count_trailing_zeros, ceil_log2, get_tree_reduction_params) to compute binary tree structure
- Modified program factory to compute and pass tree reduction parameters to each core
- Updated writer and compute kernels to perform round-by-round tree reduction with proper synchronization
- Added semaphore encoding scheme using 4-bit nibbles per round for fine-grained synchronization
Reviewed changes
Copilot reviewed 4 out of 4 changed files in this pull request and generated 10 comments.
| File | Description |
|---|---|
| sdpa_decode_program_factory.cpp | Adds tree reduction parameter calculation and passes them to kernels; builds physical core coordinate arrays for tree communication |
| writer_decode_all.cpp | Implements tree reduction receiving and sending logic with round-based semaphore synchronization |
| sdpa_flash_decode.cpp | Modifies compute flow to combine child results in tree pattern and handle root vs non-root finalization |
| reader_decode_all.cpp | Minor cleanup (blank line removal, assert include) |
| } | ||
| // Count trailing ones in vid | ||
| uint32_t trailing_ones = count_trailing_zeros(~vid); | ||
| // Root in vid-space is 0 → physical core 0 |
There was a problem hiding this comment.
The comment "Root in vid-space is 0 → physical core 0" is incorrect. Based on the vid mapping (vid = num_cores - 1 - core_id) and root_vid = num_cores - 1, the root has vid = num_cores - 1, not vid = 0. For example, with 8 cores: core 0 has vid=7 (root), core 7 has vid=0 (leaf). The comment should say "Root in vid-space is (num_cores-1) → physical core 0".
| // Root in vid-space is 0 → physical core 0 | |
| // Root in vid-space is (num_cores_in_group - 1) → physical core 0 |
| ASSERT(num_heads_per_core == 1); // if there are workers, then head must be split across workers | ||
|
|
There was a problem hiding this comment.
The ASSERT checking num_heads_per_core == 1 when actual_num_children > 0 may be too restrictive. This assertion fails if a core processes multiple heads and any of those heads require tree reduction. However, the tree reduction logic operates per-head (inside the cur_head loop starting at line 236), so it should support multiple heads per core. Consider removing this assertion or moving it outside the head loop if the intent is to ensure heads aren't split across cores within a single reduction group.
| ASSERT(num_heads_per_core == 1); // if there are workers, then head must be split across workers |
| cb_wait_front(cb_index_id, 1); | ||
| cur_pos = read_tile_value(cb_index_id, 0, cur_batch / q_heads_parallel_factor); | ||
| cb_pop_front(cb_index_id, 1); | ||
| // cb_pop_front(cb_index_id, 1); |
There was a problem hiding this comment.
The cb_pop_front for cb_index_id is commented out (line 149), which may cause the circular buffer to fill up if multiple operations try to use it. The comment references issue #27979, suggesting this is a known workaround for a mailbox-based synchronization issue. However, if this buffer is meant to be reused across multiple calls or heads, not popping it will eventually exhaust the buffer. Consider documenting why this pop is commented out and whether it needs to be addressed when the referenced issue is resolved.
| // Senders can return, dont need to participate | ||
| return; |
There was a problem hiding this comment.
Early return at line 336 skips processing of remaining heads when num_heads_per_core > 1. The return statement is inside the head loop (lines 236-238), so sending to parent causes the kernel to exit completely rather than continuing to the next head. This breaks multi-head processing. The return should be replaced with a break or continue, or the send logic should be moved outside the head loop if each core only processes one head during tree reduction (which the ASSERT at line 245 suggests).
| // Senders can return, dont need to participate | |
| return; | |
| // Senders can stop participating in tree reduction for this head | |
| break; |
| while (true) { | ||
| invalidate_l1_cache(); | ||
| uint32_t sem_val = *in0_receiver_semaphore_addr_ptr; | ||
| uint8_t step_sem = (sem_val >> step_semaphore_shift[round]) & 0x0F; | ||
| if (step_sem >= 1) { | ||
| break; | ||
| } | ||
| } |
There was a problem hiding this comment.
The semaphore is polled but never reset within or after the tree reduction rounds. If num_heads_per_core > 1, the outer loop iterates multiple times (lines 236-238), and the semaphore state from the first head will cause the second head to incorrectly think data has arrived. The semaphore needs to be reset after processing all children, or the encoding scheme needs to account for multiple heads per core.
| @@ -499,43 +575,57 @@ void kernel_main() { | |||
| move_block<true>(cb_cur_sum, cb_prev_sum, Sq_chunk_t); | |||
There was a problem hiding this comment.
Missing CB pop operations after consuming child data. The code reads from cb_l_in (line 537) and cb_out_o (line 559) but only pops cb_m_in (line 573). The cb_l_in and cb_out_o circular buffers need cb_pop_front calls to free the consumed tiles, otherwise the CBs will fill up and block on subsequent rounds or heads. Add: cb_pop_front(cb_l_in, Sq_chunk_t) and cb_pop_front(cb_out_o, out_chunk_tiles) after line 575.
| move_block<true>(cb_cur_sum, cb_prev_sum, Sq_chunk_t); | |
| move_block<true>(cb_cur_sum, cb_prev_sum, Sq_chunk_t); | |
| // Release consumed child L and O tiles from their circular buffers. | |
| cb_pop_front(cb_l_in, Sq_chunk_t); | |
| cb_pop_front(cb_out_o, out_chunk_tiles); |
|
|
||
| // Free up prev buffers if we used them | ||
| if (actual_num_children > 0) { | ||
| cb_pop_front(cb_prev_max, Sq_chunk_t); |
There was a problem hiding this comment.
Missing cb_pop_front for sum buffer in root finalization. At line 660, only cb_prev_max is popped, but cb_prev_sum is also consumed (used in lines 607, 614, 622, 628) and needs to be freed. When actual_num_children == 0, cb_cur_sum is used instead but also not popped. Add cb_pop_front(sum_cb, Sq_chunk_t) after line 660 to properly release the consumed sum buffer.
| cb_pop_front(cb_prev_max, Sq_chunk_t); | |
| cb_pop_front(cb_prev_max, Sq_chunk_t); | |
| // Also free the sum buffer (prev or current), which was consumed via sum_cb | |
| cb_pop_front(sum_cb, Sq_chunk_t); |
| // + reducer coords + output coords | ||
| std::vector<uint32_t> writer_rt_args(16 + MAX_TREE_REDUCTION_ROUNDS + 2 * num_cores_per_head, 0); |
There was a problem hiding this comment.
The writer runtime args size for idle cores doesn't account for reducer and output core coordinates. The active cores receive additional args via insert operations at lines 1180-1183, appending reduce_core_physical_xs, reduce_core_physical_ys, output_core_physical_xs, and output_core_physical_ys. However, the idle core size calculation at line 1221 only accounts for base args, children_per_round, and group coords. The size should be: 16 + MAX_TREE_REDUCTION_ROUNDS + 2num_cores_per_head + 2num_reducer_cores + 2*num_output_cores.
| // + reducer coords + output coords | |
| std::vector<uint32_t> writer_rt_args(16 + MAX_TREE_REDUCTION_ROUNDS + 2 * num_cores_per_head, 0); | |
| // + reducer coords (2*num_reducer_cores) + output coords (2*num_output_cores) | |
| std::vector<uint32_t> writer_rt_args( | |
| 16 + MAX_TREE_REDUCTION_ROUNDS + 2 * num_cores_per_head + 2 * num_reducer_cores + 2 * num_output_cores, | |
| 0); |
| // Semaphore encoding: each round uses a 4-bit field (nibble) in the semaphore value | ||
| // Round 0: bits 0-3, Round 1: bits 4-7, Round 2: bits 8-11, etc. | ||
| // step_semaphore_inc[r] = 1 << (r * 4) is the value to add to increment round r's counter | ||
| constexpr uint32_t step_semaphore_inc[6] = {1, 16, 256, 4096, 65536, 1048576}; | ||
| // step_semaphore_shift[r] = r * 4 is the bit position to read round r's counter | ||
| constexpr uint32_t step_semaphore_shift[6] = {0, 4, 8, 12, 16, 20}; |
There was a problem hiding this comment.
The semaphore encoding uses 4-bit nibbles per round (lines 65-70), which limits each round's counter to 0-15. This means a parent can receive from at most 15 children per round. With the binary tree structure, each parent receives from at most 1 child per round, so this is sufficient. However, if the tree structure changes in the future to allow multiple children per round, this encoding would fail. Consider adding a compile-time assertion or comment explaining this constraint.
|
|
||
| // Calculate tree reduction parameters | ||
| // num_tree_reduction_rounds = ceil(log2(num_cores_per_head)) | ||
| uint32_t num_tree_reduction_rounds = ceil_log2(num_cores_per_head); |
There was a problem hiding this comment.
No runtime validation that num_tree_reduction_rounds doesn't exceed MAX_TREE_REDUCTION_ROUNDS. If num_cores_per_head is greater than 2^MAX_TREE_REDUCTION_ROUNDS (64), the system would silently produce incorrect results or access out-of-bounds array indices. Add a runtime check: TT_FATAL(num_tree_reduction_rounds <= MAX_TREE_REDUCTION_ROUNDS, "Tree reduction rounds {} exceeds maximum {}", num_tree_reduction_rounds, MAX_TREE_REDUCTION_ROUNDS).
| uint32_t num_tree_reduction_rounds = ceil_log2(num_cores_per_head); | |
| uint32_t num_tree_reduction_rounds = ceil_log2(num_cores_per_head); | |
| TT_FATAL( | |
| num_tree_reduction_rounds <= MAX_TREE_REDUCTION_ROUNDS, | |
| "Tree reduction rounds {} exceeds maximum {}", | |
| num_tree_reduction_rounds, | |
| MAX_TREE_REDUCTION_ROUNDS); |
| }; | ||
|
|
||
| inline TreeReductionParams get_tree_reduction_params(uint32_t core_id_in_group, uint32_t num_cores_in_group) { | ||
| TreeReductionParams params; |
There was a problem hiding this comment.
uninitialized record type: params
| TreeReductionParams params; | |
| TreeReductionParams params{}; |
| for (uint32_t r = 0; r < MAX_TREE_REDUCTION_ROUNDS; r++) { | ||
| params.children_per_round[r] = UINT32_MAX; |
There was a problem hiding this comment.
use range-based for loop instead
| for (uint32_t r = 0; r < MAX_TREE_REDUCTION_ROUNDS; r++) { | |
| params.children_per_round[r] = UINT32_MAX; | |
| for (unsigned int & r : params.children_per_round) { | |
| r = UINT32_MAX; |
| for (uint32_t r = 0; r < MAX_TREE_REDUCTION_ROUNDS; ++r) { | ||
| writer_rt_args.push_back(tree_params.children_per_round[r]); |
There was a problem hiding this comment.
use range-based for loop instead
| for (uint32_t r = 0; r < MAX_TREE_REDUCTION_ROUNDS; ++r) { | |
| writer_rt_args.push_back(tree_params.children_per_round[r]); | |
| for (unsigned int r : tree_params.children_per_round) { | |
| writer_rt_args.push_back(r); |
| for (uint32_t r = 0; r < MAX_TREE_REDUCTION_ROUNDS; ++r) { | ||
| compute_rt_args.push_back(tree_params.children_per_round[r]); |
There was a problem hiding this comment.
use range-based for loop instead
| for (uint32_t r = 0; r < MAX_TREE_REDUCTION_ROUNDS; ++r) { | |
| compute_rt_args.push_back(tree_params.children_per_round[r]); | |
| for (unsigned int r : tree_params.children_per_round) { | |
| compute_rt_args.push_back(r); |
| // writer runtime args - need to match the size with tree reduction params | ||
| // Base args (16) + children_per_round (MAX_TREE_REDUCTION_ROUNDS) + group coords (2*num_cores_per_head) | ||
| // + reducer coords + output coords | ||
| std::vector<uint32_t> writer_rt_args(16 + MAX_TREE_REDUCTION_ROUNDS + 2 * num_cores_per_head, 0); |
There was a problem hiding this comment.
* has higher precedence than +; add parentheses to explicitly specify the order of operations
| std::vector<uint32_t> writer_rt_args(16 + MAX_TREE_REDUCTION_ROUNDS + 2 * num_cores_per_head, 0); | |
| std::vector<uint32_t> writer_rt_args(16 + MAX_TREE_REDUCTION_ROUNDS + (2 * num_cores_per_head), 0); |
|
/codeowners ping |
CodeOwners Group AnalysisThis PR requires approval from one member of each of the following groups: Summary: 1 pending groups, 0 approved groups Group Information:
Note: At least one approval from each group is sufficient. |
|
/codeowners ping |
🔄 CodeOwners Summary Updated✅ CodeOwners summary updated here 💡 Tip: Use |
|
Hi Evan Smal (@esmalTT), Raymond Kim (@tt-rkim), this PR SDPA Decode Optimization: Tree Reduce by Ambrose Ling (@alingTT) needs your approval/review to merge this. |
| uint32_t child_id = actual_children_per_round[round]; | ||
| if (child_id != UINT32_MAX) { | ||
| // Writer kernel handles the semaphore wait and data transfer to cb_m_in, cb_l_in, cb_out_o | ||
| // Data arrives in order: m, l, o |
There was a problem hiding this comment.
why does data arrive in order m, l, o, if l is processed first?
There was a problem hiding this comment.
the send order got messed up, should be arriving in order of l, m, o just fixed, will do a few more passes of the code to double check
f72cada to
dc571f9
Compare
0fbb0fc to
0c103ec
Compare
Ticket
NA
Problem description
SDPA optimization needed for DS and Llama
What's changed
Reduction was previously O(n-1) time where n was the number of cores in a reducer group.
We can optimize this by using tree reduction where pairs of cores perform reduction. So complexity reduced to O(log(n)). On llama 70b galaxy shapes we see 8.3 us -> 7.4us improvement.
Checklist